/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.aggregation.builder;

import com.facebook.presto.memory.LocalMemoryContext;
import com.facebook.presto.operator.OperatorContext;
import com.facebook.presto.operator.aggregation.AccumulatorFactory;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.gen.JoinCompiler;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.google.common.collect.ImmutableList;
import io.airlift.units.DataSize;

import java.io.Closeable;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Optional;

/**
 * Re-aggregates the pages of {@code sortedPages} in memory-bounded batches, each batch handled
 * by a fresh {@link InMemoryHashAggregationBuilder} running the partial-input form of the
 * requested aggregation step.
 *
 * <p>Input pages are expected to carry the group-by columns in channels
 * {@code 0 .. groupByTypes.size() - 1}. If a hash channel is present, it is read from channel
 * {@code groupByTypes.size()}. Output may be produced after any page boundary because
 * {@code sortedPages} does not have hash values that span multiple pages (guaranteed by
 * MergeHashSort).
 */
public class MergingHashAggregationBuilder
    implements Closeable
{
    private final List<AccumulatorFactory> accumulatorFactories;
    private final AggregationNode.Step step;
    private final int expectedGroups;
    private final ImmutableList<Integer> groupByPartialChannels;
    private final Optional<Integer> hashChannel;
    private final OperatorContext operatorContext;
    private final Iterator<Page> sortedPages;
    private final List<Type> groupByTypes;
    private final LocalMemoryContext systemMemoryContext;
    private final long memorySizeBeforeSpill;
    private final int overwriteIntermediateChannelOffset;
    private final JoinCompiler joinCompiler;

    // Builder for the batch currently being merged. It is replaced (and the previous one closed)
    // each time a new batch starts.
    private InMemoryHashAggregationBuilder hashAggregationBuilder;

    /**
     * @param accumulatorFactories accumulators passed to every per-batch builder
     * @param step aggregation step; the builders run {@code AggregationNode.Step.partialInput(step)}
     * @param expectedGroups expected group count passed to every per-batch builder
     * @param groupByTypes types of the group-by columns, which occupy the leading channels of each
     *     input page
     * @param hashChannel if present, the input hash is read from channel {@code groupByTypes.size()}
     *     (the original channel index is not used)
     * @param operatorContext context passed to every per-batch builder
     * @param sortedPages pages to merge; no hash value may span multiple pages
     * @param systemMemoryContext updated with the current builder's size after every processed page
     * @param memorySizeBeforeSpill when positive, a batch is finished once the builder's size
     *     exceeds this many bytes; when zero or negative, all input goes into a single batch
     * @param overwriteIntermediateChannelOffset forwarded to each builder as its intermediate
     *     channel offset
     * @param joinCompiler compiler passed to every per-batch builder
     */
    public MergingHashAggregationBuilder(
            List<AccumulatorFactory> accumulatorFactories,
            AggregationNode.Step step,
            int expectedGroups,
            List<Type> groupByTypes,
            Optional<Integer> hashChannel,
            OperatorContext operatorContext,
            Iterator<Page> sortedPages,
            LocalMemoryContext systemMemoryContext,
            long memorySizeBeforeSpill,
            int overwriteIntermediateChannelOffset,
            JoinCompiler joinCompiler)
    {
        ImmutableList.Builder<Integer> groupByPartialChannels = ImmutableList.builder();
        for (int i = 0; i < groupByTypes.size(); i++) {
            groupByPartialChannels.add(i);
        }

        this.accumulatorFactories = accumulatorFactories;
        this.step = AggregationNode.Step.partialInput(step);
        this.expectedGroups = expectedGroups;
        this.groupByPartialChannels = groupByPartialChannels.build();
        // In the merged input the hash column directly follows the group-by columns.
        this.hashChannel = hashChannel.isPresent() ? Optional.of(groupByTypes.size()) : hashChannel;
        this.operatorContext = operatorContext;
        this.sortedPages = sortedPages;
        this.groupByTypes = groupByTypes;
        this.systemMemoryContext = systemMemoryContext;
        this.memorySizeBeforeSpill = memorySizeBeforeSpill;
        this.overwriteIntermediateChannelOffset = overwriteIntermediateChannelOffset;
        this.joinCompiler = joinCompiler;

        // Always have a builder so close() is valid even if buildResult() is never iterated.
        rebuildHashAggregationBuilder();
    }

    /**
     * Returns a lazy iterator over the aggregated pages.
     *
     * <p>Batches are merged on demand. {@code hasNext()} keeps merging batches until one
     * produces output or {@code sortedPages} is exhausted, so a {@code true} answer always means
     * {@code next()} will return a page.
     */
    public Iterator<Page> buildResult()
    {
        return new Iterator<Page>()
        {
            private Iterator<Page> resultPages = Collections.emptyIterator();

            @Override
            public boolean hasNext()
            {
                // A merged batch may produce no output, so loop rather than assume one batch
                // always yields a page.
                while (!resultPages.hasNext() && sortedPages.hasNext()) {
                    resultPages = mergeNextBatch();
                }
                return resultPages.hasNext();
            }

            @Override
            public Page next()
            {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                return resultPages.next();
            }
        };
    }

    /**
     * Closes the builder for the current batch.
     */
    @Override
    public void close()
    {
        hashAggregationBuilder.close();
    }

    /**
     * Starts a fresh builder and feeds it sorted pages until the input runs out or the memory
     * threshold is exceeded, then returns that builder's output pages.
     */
    private Iterator<Page> mergeNextBatch()
    {
        rebuildHashAggregationBuilder();
        long memorySize = 0; // ensure that at least one merged page will be processed

        // we can produce output after every page, because sortedPages does not have
        // hash values that span multiple pages (guaranteed by MergeHashSort)
        while (sortedPages.hasNext() && !shouldProduceOutput(memorySize)) {
            hashAggregationBuilder.processPage(sortedPages.next());
            memorySize = hashAggregationBuilder.getSizeInMemory();
            systemMemoryContext.setBytes(memorySize);
        }
        return hashAggregationBuilder.buildResult();
    }

    /**
     * Returns true when a positive memory limit is configured and {@code memorySize} exceeds it.
     */
    private boolean shouldProduceOutput(long memorySize)
    {
        return (memorySizeBeforeSpill > 0 && memorySize > memorySizeBeforeSpill);
    }

    /**
     * Replaces the current builder with a new, empty one, closing the previous builder first so
     * its resources are not leaked. Callers only rebuild once the previous builder's output has
     * been fully consumed.
     */
    private void rebuildHashAggregationBuilder()
    {
        if (hashAggregationBuilder != null) {
            hashAggregationBuilder.close();
        }

        this.hashAggregationBuilder = new InMemoryHashAggregationBuilder(
                accumulatorFactories,
                step,
                expectedGroups,
                groupByTypes,
                groupByPartialChannels,
                hashChannel,
                operatorContext,
                DataSize.succinctBytes(0),
                Optional.of(overwriteIntermediateChannelOffset),
                joinCompiler);
    }
}
